{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Notebook for training a BERT-BCA model\n",
    "\n",
    "This notebook can be used to train a BCA model with BERT token embedding as input. The input can be generated using BERT as service (https://github.com/hanxiao/bert-as-service). The script BERT_text_representation.py can be used to create the BERT embedding from the TOEFL essay data after running TOEFL_dataParse.py "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from __future__ import print_function, division\n",
    "\n",
    "import os\n",
    "import os.path\n",
    "import pandas as pd\n",
    "from io import StringIO\n",
    "import io\n",
    "import unicodedata\n",
    "import re\n",
    "import random\n",
    "\n",
    "import tensorflow as tf\n",
    "import numpy as np\n",
    "np.set_printoptions(threshold = 10000)\n",
    "import collections\n",
    "import random\n",
    "\n",
    "from tensorflow.contrib.rnn import LSTMCell as Cell #for GRU: custom implementation with normalization\n",
    "from tensorflow.python.ops.rnn import dynamic_rnn as rnn\n",
    "from tensorflow.python.ops.rnn import bidirectional_dynamic_rnn as bi_rnn\n",
    "from tensorflow.contrib.rnn import DropoutWrapper\n",
    "\n",
    "from attention import attention as attention\n",
    "from bca_ import *\n",
    "from ordloss import *\n",
    "from utils import *\n",
    "#from datautilsbca import *\n",
    "\n",
    "\n",
    "from numpy import array\n",
    "from numpy import argmax\n",
    "from sklearn.preprocessing import LabelEncoder\n",
    "from sklearn.preprocessing import OneHotEncoder\n",
    "from scipy import stats\n",
    "from sklearn.metrics import accuracy_score\n",
    " "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "SEQUENCE_LENGTH = 40\n",
    "SEQUENCE_LENGTH_D = 25\n",
    "max_vocab = 75000\n",
    "train_split = 0.9\n",
    "BATCH_SIZE = 20\n",
    "\n",
    "# system parameters\n",
    "HIDDEN_SIZE = 150\n",
    "HIDDEN_SIZE_D = 150\n",
    "ATTENTION_SIZE = 75\n",
    "ATTENTION_SIZE_D = 50\n",
    "LAYER_1 = 500\n",
    "LAYER_2 = 250\n",
    "LAYER_3 = 100\n",
    "KEEP_PROB = 0.7\n",
    "#NUM_EPOCHS = 1  # max val_acc at __\n",
    "DELTA = 0.75"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "X_train = np.load('data/TOEFL/X_train_TOEFL.npy')\n",
    "y_train = np.load('data/TOEFL/y_train_TOEFL.npy')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "SEQUENCE_LEN_D = SEQUENCE_LENGTH_D\n",
    "SEQUENCE_LEN = SEQUENCE_LENGTH\n",
    "tr_len = int(train_split*len(y_train))\n",
    "X_train, y_train, X_val, y_val  = X_train[:tr_len*SEQUENCE_LEN_D], y_train[:tr_len], X_train[tr_len*SEQUENCE_LEN_D:], y_train[tr_len:]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "tf.reset_default_graph()\n",
    "tf.set_random_seed(111)\n",
    "b_len = 768"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#Different placeholders\n",
    "#num_classes = y_train.shape[1]\n",
    "num_classes = 3\n",
    "batch_ph = tf.placeholder(tf.float32, [None, SEQUENCE_LENGTH, b_len])\n",
    "ind_list_ph = tf.placeholder(tf.int32, [None])\n",
    "target_ph = tf.placeholder(tf.float32, [None,num_classes])\n",
    "\n",
    "seq_len_ph = tf.placeholder(tf.int32, [None])\n",
    "seq_len_ph_d = tf.placeholder(tf.int32, [None])\n",
    "keep_prob_ph = tf.placeholder(tf.float32)\n",
    "doc_size_ph = tf.placeholder(tf.int32,[None])\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "W_omega = tf.Variable(tf.random_uniform([HIDDEN_SIZE*2, HIDDEN_SIZE*2], -1.0, 1.0))\n",
    "\n",
    "# (Bi-)RNN layer(-s)\n",
    "with tf.variable_scope('sentence'):\n",
    "    fw_cell = Cell(HIDDEN_SIZE)\n",
    "    bw_cell = Cell(HIDDEN_SIZE)\n",
    "    \n",
    "    fw_cell = DropoutWrapper(fw_cell, input_keep_prob=keep_prob_ph, \n",
    "                             output_keep_prob=keep_prob_ph,state_keep_prob=keep_prob_ph,\n",
    "                             variational_recurrent=True, input_size=batch_ph.get_shape()[-1], \n",
    "                             dtype = tf.float32)\n",
    "    bw_cell = DropoutWrapper(bw_cell, input_keep_prob=keep_prob_ph, \n",
    "                             output_keep_prob=keep_prob_ph,state_keep_prob= keep_prob_ph,\n",
    "                             variational_recurrent=True, input_size=batch_ph.get_shape()[-1], \n",
    "                             dtype = tf.float32)\n",
    "    rnn_output, _ = bi_rnn(fw_cell, bw_cell, inputs=batch_ph, sequence_length=seq_len_ph, dtype=tf.float32)\n",
    "\n",
    "    \n",
    "    rnn_outputs_ = cross_attention(rnn_output, SEQUENCE_LENGTH_D, seq_len_ph, BATCH_SIZE, W_omega)\n",
    "    attention_output_, alphas_ = attention(rnn_outputs_ , ATTENTION_SIZE, seq_len_ph, return_alphas = True)\n",
    "    attention_output_ = tf.reshape(attention_output_,[BATCH_SIZE, -1, HIDDEN_SIZE*2*3])\n",
    "    \n",
    "with tf.variable_scope('document'):\n",
    "    fw_cell_d = Cell(HIDDEN_SIZE_D)\n",
    "    bw_cell_d = Cell(HIDDEN_SIZE_D)\n",
    "    \n",
    "    fw_cell_d = DropoutWrapper(fw_cell_d, input_keep_prob=keep_prob_ph, \n",
    "                             output_keep_prob=keep_prob_ph,state_keep_prob=keep_prob_ph,\n",
    "                             variational_recurrent=True, input_size=attention_output_.get_shape()[-1], \n",
    "                             dtype = tf.float32)\n",
    "    bw_cell_d = DropoutWrapper(bw_cell_d, input_keep_prob=keep_prob_ph, \n",
    "                             output_keep_prob=keep_prob_ph,state_keep_prob= keep_prob_ph,\n",
    "                             variational_recurrent=True, input_size=attention_output_.get_shape()[-1], \n",
    "                             dtype = tf.float32)\n",
    "    rnn_outputs_d, _ = bi_rnn(fw_cell_d, bw_cell_d, inputs=attention_output_, \n",
    "                              sequence_length=seq_len_ph_d, dtype=tf.float32)\n",
    "    \n",
    "    attention_output_d, alphas_d = attention(rnn_outputs_d, ATTENTION_SIZE_D, seq_len_ph_d, return_alphas=True)\n",
    "\n",
    "# Dropout\n",
    "drop = tf.nn.dropout(attention_output_d, keep_prob_ph)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "ordinal = True\n",
    "if ordinal:\n",
    "    # For ordinal regression, same weights for each class\n",
    "    W = tf.Variable(tf.truncated_normal([drop.get_shape()[1].value], stddev=0.1))\n",
    "    W_ = tf.transpose(tf.reshape(tf.tile(W,[num_classes - 1]),[num_classes - 1, drop.get_shape()[1].value]))\n",
    "    b = tf.Variable(tf.cast(tf.range(num_classes - 1), dtype = tf.float32))\n",
    "    y_hat_ = tf.nn.xw_plus_b(drop, tf.negative(W_), b)\n",
    "\n",
    "    # Predicted labels and logits\n",
    "    y_preds, logits = preds(y_hat_,BATCH_SIZE)\n",
    "    y_true = tf.argmax(target_ph, axis = 1)\n",
    "\n",
    "    # Ordinal loss\n",
    "    loss = ordloss_m(y_hat_, target_ph, BATCH_SIZE)\n",
    "    c = stats.spearmanr\n",
    "    str_score = \"Spearman rank:\"\n",
    "    \n",
    "# Calculate and clip gradients\n",
    "max_gradient_norm = 5\n",
    "lr = 1e-4\n",
    "params = tf.trainable_variables()\n",
    "gradients = tf.gradients(loss, params)\n",
    "clipped_gradients, _ = tf.clip_by_global_norm(gradients, max_gradient_norm)\n",
    "optimizer_ = tf.train.AdamOptimizer(learning_rate=lr)\n",
    "optimizer = optimizer_.apply_gradients(\n",
    "    zip(clipped_gradients, params))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "saver = tf.train.Saver()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "batch_counter = 0\n",
    "val_counter = []\n",
    "config = tf.ConfigProto(inter_op_parallelism_threads=24,\n",
    "                        intra_op_parallelism_threads=24)\n",
    "config.gpu_options.allow_growth = True\n",
    "sess = tf.Session(config = config)\n",
    "sess.run(tf.global_variables_initializer())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#Main training task\n",
    "train_batch_generator = batch_generator(X_train, y_train, BATCH_SIZE, seq_len = SEQUENCE_LENGTH_D)\n",
    "val_batch_generator = batch_generator(X_val, y_val, BATCH_SIZE, seq_len = SEQUENCE_LENGTH_D)\n",
    "\n",
    "train_accuracy = []\n",
    "val_accuracy = []\n",
    "val_counter = []\n",
    "val_count = 50\n",
    "NUM_EPOCHS = 60\n",
    "doc_size_np = np.array([0]*SEQUENCE_LENGTH_D)\n",
    "batch_counter = 0\n",
    "\n",
    "loss_train = 0\n",
    "sq_l = np.array([SEQUENCE_LENGTH]*BATCH_SIZE*SEQUENCE_LENGTH_D)\n",
    "seq_len_d = [SEQUENCE_LENGTH_D]*BATCH_SIZE               \n",
    "seq_len_d = np.array(seq_len_d)\n",
    "\n",
    "print('Training on TOEFL data')\n",
    "for epoch in range(NUM_EPOCHS):\n",
    "    print(\"epoch: {}\\t\".format(epoch), end=\"\")\n",
    "\n",
    "    # Training\n",
    "    num_batches = X_train.shape[0] // (BATCH_SIZE*SEQUENCE_LENGTH_D)\n",
    "    true = []\n",
    "    ypreds = []\n",
    "\n",
    "    for bx in range(num_batches):\n",
    "        batch_counter += 1\n",
    "        x_batch, y_batch = next(train_batch_generator)\n",
    "        \n",
    "        y_preds_, loss_tr,  _  = sess.run([y_preds, loss,  optimizer],\n",
    "                                   feed_dict={batch_ph: x_batch,\n",
    "                                              target_ph: y_batch,\n",
    "                                              seq_len_ph: sq_l,\n",
    "                                              seq_len_ph_d: seq_len_d,\n",
    "                                              doc_size_ph: doc_size_np,\n",
    "                                              keep_prob_ph: KEEP_PROB})\n",
    "        loss_train = loss_tr * DELTA + loss_train * (1 - DELTA)\n",
    "        ypreds.extend(y_preds_)\n",
    "        t = np.argmax(y_batch, axis = 1)\n",
    "        true.extend(t)\n",
    "\n",
    "        sp = c(y_preds_,t)\n",
    "        if ordinal: \n",
    "            sp = sp[0]\n",
    "        train_accuracy.append(sp)\n",
    "        \n",
    "        \n",
    "        #testing on the validation set            \n",
    "        if batch_counter%val_count == 0:\n",
    "            val_counter.append(batch_counter)\n",
    "            x_batch, y_batch = next(val_batch_generator)\n",
    "            \n",
    "            y_preds_,loss_t = sess.run([y_preds,loss],\n",
    "                          feed_dict={batch_ph: x_batch,\n",
    "                                target_ph: y_batch,\n",
    "                                seq_len_ph: sq_l,\n",
    "                                seq_len_ph_d: seq_len_d,\n",
    "                                doc_size_ph: doc_size_np,\n",
    "                                keep_prob_ph: 1.0})\n",
    "            ypreds.extend(y_preds_)\n",
    "            t = np.argmax(y_batch, axis = 1)\n",
    "            true.extend(t)\n",
    "\n",
    "            sp = c(y_preds_,t)\n",
    "            if ordinal: \n",
    "                sp = sp[0]\n",
    "            val_accuracy.append(sp)\n",
    "            #saver.save(sess, MODEL_PATH, global_step = batch_counter)\n",
    "\n",
    "    print('training loss: ' + str(loss_train))\n",
    "    spr = c(true, ypreds)\n",
    "    if ordinal:\n",
    "        spr = spr[0]\n",
    "    print('Training '+ str_score + str(spr))\n",
    "    print('Val ' + str(np.mean(val_accuracy)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def zero_pad_test_(X, seq_len_div,b_len = 768):\n",
    "    if (len(X)%seq_len_div) == 0:\n",
    "        return np.array([x for x in X])\n",
    "    diff = seq_len_div - (len(X)%seq_len_div)\n",
    "    return np.concatenate((np.array([x for x in X]),np.zeros((diff,len(X[0]),b_len))), axis = 0)\n",
    "\n",
    "def zero_pad_test(X, seq_len_div,b_len = 768):\n",
    "    if (len(X)%seq_len_div) == 0:\n",
    "        return np.array([x for x in X])\n",
    "    diff = seq_len_div - (len(X)%seq_len_div)\n",
    "    return np.concatenate((np.array([x for x in X]),np.zeros((diff,len(X[0])))), axis = 0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "X_test = np.load('data/TOEFL/X_test_TOEFL.npy')\n",
    "y_test = np.load('data/TOEFL/y_test_TOEFL.npy')\n",
    "X_test = zero_pad_test_(X_test, BATCH_SIZE*SEQUENCE_LENGTH_D)\n",
    "y_test = zero_pad_test(y_test, BATCH_SIZE)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_batch_generator = batch_generator(X_test, y_test, BATCH_SIZE, seq_len = SEQUENCE_LENGTH_D, shuffle = False)\n",
    "#testing on the test set\n",
    "num_batches = X_test.shape[0] // (BATCH_SIZE*SEQUENCE_LENGTH_D)\n",
    "true = []\n",
    "ypreds = []\n",
    "\n",
    "for bx in range(num_batches):\n",
    "    x_batch, y_batch = next(test_batch_generator)\n",
    "    y_preds_= sess.run([y_preds],\n",
    "                  feed_dict={batch_ph: x_batch,\n",
    "                        target_ph: y_batch,\n",
    "                        seq_len_ph: sq_l,\n",
    "                        seq_len_ph_d: seq_len_d,\n",
    "                        doc_size_ph: doc_size_np,\n",
    "                        keep_prob_ph: 1.0})\n",
    "    ypreds.extend(y_preds_)\n",
    "    t = np.argmax(y_batch, axis = 1)\n",
    "    true.extend(t)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#len(ypreds[0])\n",
    "ypreds = [j for sub in ypreds for j in sub]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "y_test_len = len(y_test)\n",
    "true = true[:y_test_len]\n",
    "ypreds = ypreds[:y_test_len]\n",
    "\n",
    "spr = c(true, ypreds)\n",
    "\n",
    "if ordinal:\n",
    "    spr = spr[0]\n",
    "print('Test set '+ str_score + str(spr))\n",
    "\n",
    "rank = stats.spearmanr\n",
    "print('sp rho')\n",
    "print(rank(true, ypreds))\n",
    "\n",
    "from sklearn.metrics import cohen_kappa_score as kappa\n",
    "print('qwk')\n",
    "print(kappa(true, ypreds, weights=\"quadratic\"))\n",
    "\n",
    "from scipy.stats import pearsonr\n",
    "print('pearson')\n",
    "print(pearsonr(true,ypreds))\n",
    "\n",
    "print('kappa')\n",
    "print(kappa(true, ypreds, weights=None))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
